import copy
import os
import sys
from collections import OrderedDict

import torch
import numpy as np
from numpy import random
# Working directory of the running process (not necessarily this file's folder).
path = os.getcwd()
# Add the parent of the working directory to the module search path;
# presumably this is what lets `from model import ...` below resolve -- confirm
# against the project layout / launch directory.
sys.path.append(os.path.abspath(os.path.join(path, os.pardir)))

from model import binarization

class Server:
    """Central server for a federated-learning simulation.

    The behaviour depends on ``args.mask``:

    * mask mode (``args.mask`` truthy): the server tracks every parameter whose
      name contains ``'threshold'`` and averages only those across clients.
    * FedAvg mode: the server owns a private copy of the full model and
      averages complete ``state_dict``s.

    Attributes read from ``args``: ``num_clients``, ``schedulingsize``,
    ``mask`` and, for ``increase_thresholds``, ``th_coeff`` and
    ``learning_rate``.
    """

    def __init__(self, args, model):
        """Store the configuration and initialise the global state.

        Args:
            args: configuration namespace (see class docstring).
            model: ``torch.nn.Module`` used as the starting point. In mask mode
                it is kept by reference as ``initial_model``; in FedAvg mode a
                deep copy becomes ``global_model``.
        """
        self.clients_list = np.arange(args.num_clients)
        self.args = args
        if self.args.mask:
            self.initial_model = model
            self.global_thresholds = OrderedDict()
            self.set_global_thresholds()
            self.global_difference = OrderedDict()
        else:
            self.global_model = copy.deepcopy(model)

    def set_global_thresholds(self):
        """Snapshot the threshold parameters of ``initial_model``.

        The values are stored as detached clones rather than the live
        ``nn.Parameter`` objects: server-side in-place updates must neither
        modify ``initial_model`` nor hit autograd's ban on in-place operations
        on leaf tensors that require grad.
        """
        for name, param in self.initial_model.named_parameters():
            if 'threshold' in name:
                self.global_thresholds[name] = param.detach().clone()

    def sample_clients(self):
        """Draw the clients that participate in the current round.

        Returns:
            Array of ``args.schedulingsize`` distinct integer indices drawn
            from ``range(args.num_clients)``.
        """
        return np.random.choice(
            self.args.num_clients, self.args.schedulingsize, replace=False
        )

    def broadcast(self, Clients_list, Clients_list_idx=None):
        """Push the current global state to clients.

        Args:
            Clients_list: sequence of client objects exposing a ``model``
                attribute (a ``torch.nn.Module``).
            Clients_list_idx: FedAvg mode only -- indices into ``Clients_list``
                of the clients to update; ``None`` means every client. In mask
                mode this argument is ignored and every client in
                ``Clients_list`` receives the global thresholds.
        """
        if self.args.mask:
            with torch.no_grad():
                for client in Clients_list:
                    for name, param in client.model.named_parameters():
                        if 'threshold' in name:
                            param.copy_(self.global_thresholds[name])
        else:  # FedAvg
            if Clients_list_idx is None:
                Clients_list_idx = range(len(Clients_list))
            # load_state_dict copies values into the client's own tensors, so
            # one shared state dict can be reused for every client.
            global_state = self.global_model.state_dict()
            for client_idx in Clients_list_idx:
                Clients_list[client_idx].model.load_state_dict(global_state)

    def aggregation(self, Clients_list, sampling_set):
        """Average the sampled clients' updates into the global state.

        Each contribution is divided by ``args.schedulingsize``.

        In mask mode only parameters whose name contains ``'threshold'`` are
        averaged; ``global_difference`` is set to ``new - old`` for every
        previously known threshold and ``global_thresholds`` is replaced by the
        new averages. In FedAvg mode every ``state_dict`` entry is averaged and
        loaded into ``global_model``.

        Args:
            Clients_list: sequence of client objects exposing ``model``.
            sampling_set: iterable of integer indices into ``Clients_list``.
        """
        size = self.args.schedulingsize
        if self.args.mask:
            averaged = OrderedDict()
            # Averaging is bookkeeping, not part of any training graph: keep
            # autograd out so the results do not retain client parameters.
            with torch.no_grad():
                for client_idx in sampling_set:
                    named = Clients_list[client_idx].model.named_parameters()
                    for name, param in named:
                        if 'threshold' not in name:
                            continue
                        contribution = param.detach() / size
                        if name in averaged:
                            averaged[name] += contribution
                        else:
                            averaged[name] = contribution

                for key in self.global_thresholds:
                    self.global_difference[key] = (
                        averaged[key] - self.global_thresholds[key]
                    )

            self.global_thresholds = averaged

        else:  # FedAvg
            weight_dict = OrderedDict()
            for client_idx in sampling_set:
                local_state = Clients_list[client_idx].model.state_dict()
                for key, value in local_state.items():
                    contribution = value / size
                    if key in weight_dict:
                        weight_dict[key] += contribution
                    else:
                        weight_dict[key] = contribution

            self.global_model.load_state_dict(weight_dict)

    def increase_thresholds(self):
        """Raise every global threshold in place (mask mode only).

        Each threshold ``t`` becomes
        ``t + args.th_coeff * args.learning_rate * exp(-t)``.
        """
        step = self.args.th_coeff * self.args.learning_rate
        with torch.no_grad():
            for threshold in self.global_thresholds.values():
                threshold += step * torch.exp(-threshold)